import os
import re # re模块提供了各种各样的正则表达式方法
import pin
import assembly as ASM

# The assembly source and the generated binary both live next to this script.
dirname = os.path.dirname(__file__)

inputfile, outputfile = (
    os.path.join(dirname, filename)
    for filename in ('program.asm', 'program.bin')
)

# Captures everything that precedes the first ';' on a line (the part that is
# not a comment).
annotation = re.compile(r"(.*?);.*")

# Module-level state filled in by compile_program():
#   codes - every parsed Code object, in source order
#   marks - label name -> the Code object the label refers to
codes = list()
marks = dict()

# Mnemonic -> opcode tables, grouped by the number of operands taken.
OP2 = {
    mnemonic: getattr(ASM, mnemonic)
    for mnemonic in ('MOV', 'ADD', 'SUB', 'CMP', 'AND', 'OR', 'XOR')
}

OP1 = {
    mnemonic: getattr(ASM, mnemonic)
    for mnemonic in ('INC', 'DEC', 'NOT', 'JMP')
}

OP0 = {
    mnemonic: getattr(ASM, mnemonic)
    for mnemonic in ('NOP', 'HLT')
}

# Opcode values of each group, for fast membership tests while encoding.
OP2SET = {*OP2.values()}
OP1SET = {*OP1.values()}
OP0SET = {*OP0.values()}

# Register name -> register code as defined in the pin module.
REGISTERS = {
    register: getattr(pin, register)
    for register in ("A", "B", "C", "D")
}

class Code(object):
    """One non-empty line of the assembly program.

    A line is either an instruction (``TYPE_CODE``), e.g. ``MOV A, 5`` or
    ``JMP increase``, or a label definition (``TYPE_LABEL``), e.g.
    ``increase:``.  The text is upper-cased on construction, so mnemonics,
    register names and labels are case-insensitive.
    """

    TYPE_CODE = 1
    TYPE_LABEL = 2

    # Operand patterns, compiled once instead of on every get_am() call.
    _DECIMAL = re.compile(r'^[0-9]+$')
    _HEXADECIMAL = re.compile(r'^0X[0-9A-F]+$')
    _DIRECT_DECIMAL = re.compile(r'^\[([0-9]+)\]$')
    _DIRECT_HEXADECIMAL = re.compile(r'^\[(0X[0-9A-F]+)\]$')
    _BRACKETED = re.compile(r'^\[(.+)\]$')

    def __init__(self, number, source: str):
        """Parse ``source``; ``number`` is its 1-based line number.

        Raises:
            SyntaxError: if the line has too many operands or no mnemonic.
        """
        # Line number. The attribute spelling is kept for backward
        # compatibility.
        self.numer = number
        # Surrounding whitespace is dropped so that text left over after a
        # trailing comment (e.g. "INC A ") parses like the bare statement.
        self.source = source.strip().upper()
        self.op = None
        self.dst = None
        self.src = None
        self.name = None  # label name; only set for TYPE_LABEL lines
        self.type = self.TYPE_CODE
        self.index = 0  # position among the instructions, assigned later
        self.prepare_source()

    def get_op(self):
        """Return the opcode for this line's mnemonic.

        Raises:
            SyntaxError: if the mnemonic is not in OP2, OP1 or OP0.
        """
        for table in (OP2, OP1, OP0):
            if self.op in table:
                return table[self.op]
        raise SyntaxError(self)

    def get_am(self, addr):
        """Resolve an operand string to ``(addressing mode, value)``.

        Returns ``(None, None)`` for a missing operand.  Recognised forms, in
        order of precedence: a known label, a register name, a decimal or
        ``0X`` hexadecimal immediate, ``[number]`` and ``[register]``.

        Raises:
            SyntaxError: if the operand matches none of these forms.
        """
        if not addr:
            return None, None
        if addr in marks:
            # A label resolves to the offset of its target instruction; every
            # instruction is emitted as three values [ir, dst, src].
            return pin.AM_INS, marks[addr].index * 3
        if addr in REGISTERS:
            return pin.AM_REG, REGISTERS[addr]
        if self._DECIMAL.match(addr):
            return pin.AM_INS, int(addr)
        if self._HEXADECIMAL.match(addr):
            return pin.AM_INS, int(addr, 16)

        match = self._DIRECT_DECIMAL.match(addr)
        if match:
            return pin.AM_DIR, int(match.group(1))
        match = self._DIRECT_HEXADECIMAL.match(addr)
        if match:
            return pin.AM_DIR, int(match.group(1), 16)
        match = self._BRACKETED.match(addr)
        if match and match.group(1) in REGISTERS:
            return pin.AM_RAM, REGISTERS[match.group(1)]

        raise SyntaxError(self)

    def prepare_source(self):
        """Split the source text into ``op``, ``dst`` and ``src``.

        ``MOV A, 5`` gives op ``MOV``, dst ``A``, src ``5``; ``JMP INCREASE``
        gives op ``JMP``, dst ``INCREASE``.  A line ending in ``:`` is
        recorded as a label instead.

        Raises:
            SyntaxError: on more than one comma, more than one word after the
                mnemonic, or a missing mnemonic.
        """
        if self.source.endswith(':'):
            self.type = self.TYPE_LABEL
            self.name = self.source.strip(':').strip()
            return

        parts = self.source.split(',')
        if len(parts) > 2:
            raise SyntaxError(self)
        if len(parts) == 2:
            self.src = parts[1].strip()

        # str.split() with no separator splits on any run of whitespace and
        # ignores leading/trailing whitespace, so tabs and stray spaces are
        # accepted.
        words = parts[0].split()
        if not words or len(words) > 2:
            raise SyntaxError(self)
        if len(words) == 2:
            self.dst = words[1]

        self.op = words[0]

    def compile_code(self):
        """Encode this instruction as the list ``[ir, dst, src]``.

        ``ir`` is the opcode combined with the addressing modes:
        ``op | (amd << 2) | ams`` for two-operand instructions and
        ``op | amd`` for one-operand instructions.  Missing operands are
        encoded as 0.

        Raises:
            SyntaxError: if the mnemonic or an operand is invalid, or the
                operands are not an allowed combination for the instruction.
        """
        op = self.get_op()
        amd, dst = self.get_am(self.dst)
        ams, src = self.get_am(self.src)

        # Operands are tested with "is not None" rather than truthiness: an
        # operand whose value is 0 (e.g. "JMP 0") is still an operand and
        # must be validated.
        try:
            if src is not None:
                valid = (amd, ams) in ASM.INSTRUCTIONS[2][op]
            elif dst is not None:
                valid = amd in ASM.INSTRUCTIONS[1][op]
            else:
                valid = op in ASM.INSTRUCTIONS[0]
        except LookupError:
            # Presumably the table for this operand count has no entry for
            # the opcode, i.e. the instruction was given the wrong number of
            # operands.
            raise SyntaxError(self) from None
        if not valid:
            raise SyntaxError(self)

        amd = amd or 0
        ams = ams or 0
        dst = dst or 0
        src = src or 0

        if op in OP2SET:
            ir = op | (amd << 2) | ams
        elif op in OP1SET:
            ir = op | amd
        else:
            ir = op

        return [ir, dst, src]

    def __repr__(self):
        return f'[{self.numer}] - {self.source}'

class SyntaxError(Exception):
    """Error raised for an assembly line that cannot be assembled.

    The offending line is available as the ``code`` attribute.  Note that
    inside this module the name shadows the builtin ``SyntaxError``.
    """

    def __init__(self,code: Code, *args, **kwargs):
        # Remember the failing line first; any remaining arguments are handed
        # to Exception unchanged.
        self.code = code
        super().__init__(*args, **kwargs)

def compile_program():
    """Assemble ``inputfile`` into ``outputfile``.

    Steps:
      1. Read the source, drop ``;`` comments and blank lines, and parse each
         remaining line into a ``Code`` object (stored in ``codes``).
      2. Append a final ``HLT`` instruction.
      3. Bind every label to the instruction that follows it (``marks``) and
         number the instructions.
      4. Encode every instruction and write the values, one byte each.

    The output file is only opened once the whole program has been encoded,
    so a failed build does not leave a truncated binary behind.

    Raises:
        SyntaxError: if a line cannot be parsed or encoded, or if an encoded
            value does not fit into one byte.
    """
    global codes
    global marks

    # Start from a clean slate so that calling this function more than once
    # does not accumulate lines and labels from earlier runs.
    codes.clear()
    marks.clear()

    with open(inputfile, encoding='utf8') as file:
        lines = file.readlines()

    for number, line in enumerate(lines, start=1):
        # Keep only the text before the first ';' and strip it again, so that
        # "INC A ; comment" is parsed as "INC A" rather than "INC A ".
        source = line.partition(';')[0].strip()
        if not source:
            continue
        codes.append(Code(number, source))

    # Always terminate the program with HLT. Its line number is derived from
    # the number of lines read, so this also works for an empty input file.
    codes.append(Code(len(lines) + 1, 'HLT'))

    # Walk backwards so that each label is bound to the nearest instruction
    # after it; the trailing HLT guarantees such an instruction exists.
    instructions = []
    current = None
    for code in reversed(codes):
        if code.type == Code.TYPE_CODE:
            current = code
            instructions.append(code)
        elif code.type == Code.TYPE_LABEL:
            marks[code.name] = current
        else:
            raise SyntaxError(code)
    instructions.reverse()

    # Indices must be final before encoding: label operands are resolved from
    # the index of the instruction they point at.
    for index, code in enumerate(instructions):
        code.index = index

    image = bytearray()
    for code in instructions:
        for value in code.compile_code():
            # Each value is stored in a single byte; report anything that
            # does not fit against the line that produced it.
            if not 0 <= value <= 0xFF:
                raise SyntaxError(code)
            image.append(value)

    with open(outputfile, 'wb') as file:
        file.write(image)

def main():
    """Run the assembler and print either the failing line or a success note."""
    try:
        compile_program()
    except SyntaxError as error:
        # error.code is the Code object of the line that could not be assembled.
        message = f'Syntax error at {error.code}'
    else:
        message = 'compile program.asm finished!!!'

    print(message)
        
# Run the assembler only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()